{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert dynamic computational graph to ONNX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# paddlepaddle and paddle2onnx are required for paddle.onnx.export.\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel,\n",
    "# and the minimum versions are enforced by the requirement specifiers.\n",
    "%pip install \"paddlepaddle>=2.0.0\" \"paddle2onnx>=0.3.2\"\n",
    "# onnx and onnxruntime are optional: they are only used to check and run the ONNX model.\n",
    "%pip install \"onnx>=1.7.0\" \"onnxruntime>=1.5.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Build a dynamic computational network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "from paddle import nn\n",
    "from paddle.static import InputSpec\n",
    "import paddle2onnx as p2o\n",
    "\n",
    "class LinearNet(nn.Layer):\n",
    "    \"\"\"Network made of a single fully connected layer, used as the model to export.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Maps 784 input features to 10 output features.\n",
    "        self._linear = nn.Linear(784, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the linear layer to ``x`` and return its output.\"\"\"\n",
    "        out = self._linear(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Export the model to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2020-11-18 07:00:39 [INFO]\tONNX model saved in onnx.save/linear_net.onnx\n"
     ]
    }
   ],
   "source": [
    "# Build the network and switch it to evaluation mode before exporting.\n",
    "layer = LinearNet()\n",
    "layer.eval()\n",
    "\n",
    "# Describe the model input: float32 rows of 784 features, first dimension left as None.\n",
    "x_spec = InputSpec(shape=[None, 784], dtype='float32', name='x')\n",
    "\n",
    "# Path prefix handed to the exporter; the later cells read save_path + '.onnx'.\n",
    "save_path = 'onnx.save/linear_net'\n",
    "paddle.onnx.export(layer, save_path, input_spec=[x_spec])\n",
    "# Alternative call through paddle2onnx, left disabled:\n",
    "# p2o.dygraph2onnx(layer, save_path+'.onnx', input_spec=[x_spec])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Verify the correctness of the ONNX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model is checked!\n"
     ]
    }
   ],
   "source": [
    "# check by ONNX\n",
    "import onnx\n",
    "\n",
    "# The exported file sits at the export prefix with an '.onnx' suffix.\n",
    "onnx_file = f'{save_path}.onnx'\n",
    "\n",
    "# check_model raises if the loaded graph is not a valid ONNX model,\n",
    "# so reaching the print below means the validation succeeded.\n",
    "onnx_model = onnx.load(onnx_file)\n",
    "onnx.checker.check_model(onnx_model)\n",
    "print('The model is checked!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Run inference on the ONNX model with ONNX Runtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported model has been predict by ONNXRuntime!\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import onnxruntime\n",
    "\n",
    "# Fixed seed so the sample batch (2 rows x 784 features) is identical on every run.\n",
    "rng = np.random.default_rng(seed=0)\n",
    "x = rng.random((2, 784), dtype=np.float32)\n",
    "\n",
    "# predict by ONNX Runtime\n",
    "# Providers are passed explicitly: ONNX Runtime >= 1.9 refuses to create a session\n",
    "# without them on builds that ship more than the CPU execution provider.\n",
    "ort_sess = onnxruntime.InferenceSession(\n",
    "    onnx_file, providers=onnxruntime.get_available_providers())\n",
    "# Feed the batch under the name of the model's first declared input.\n",
    "ort_inputs = {ort_sess.get_inputs()[0].name: x}\n",
    "# Passing None as the output list requests every model output.\n",
    "ort_outs = ort_sess.run(None, ort_inputs)\n",
    "\n",
    "print(\"Exported model has been predict by ONNXRuntime!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Compare ONNX Runtime and Paddle results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The difference of result between ONNXRuntime and Paddle looks good!\n"
     ]
    }
   ],
   "source": [
    "# predict by Paddle\n",
    "# Convert the NumPy batch to a Paddle tensor explicitly instead of relying on\n",
    "# the layer accepting a raw ndarray.\n",
    "paddle_outs = layer(paddle.to_tensor(x))\n",
    "\n",
    "# compare ONNX Runtime and Paddle results\n",
    "# A relative tolerance of 1.0 would accept a 100% relative error and make this\n",
    "# check almost meaningless, so a tight tolerance is used instead.\n",
    "np.testing.assert_allclose(\n",
    "    ort_outs[0], paddle_outs.numpy(), rtol=1e-03, atol=1e-05)\n",
    "\n",
    "print(\"The difference of result between ONNXRuntime and Paddle looks good!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
